import argparse
from typing import OrderedDict

import numpy as np
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torchvision import datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to bool.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string, so ``--add_patch False`` used to enable patches.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Colored MNIST')
parser.add_argument('--hidden_dim', type=int, default=390)
parser.add_argument('--l2_regularizer_weight', type=float, default=0.001)
parser.add_argument('--inner_lr', type=float, default=0.1)
parser.add_argument('--outer_lr', type=float, default=0.001)
parser.add_argument('--dropgrad', type=float, default=0.0)
parser.add_argument('--n_restarts', type=int, default=3)
parser.add_argument('--steps', type=int, default=66)
parser.add_argument('--inner_steps', type=int, default=2)
parser.add_argument('--class_no', type=int, default=2)
parser.add_argument('--split', type=int, default=2)
# Accepts a bare `--add_patch` as well as `--add_patch True` / `--add_patch False`.
parser.add_argument('--add_patch', type=_str2bool, nargs='?', const=True, default=False)
parser.add_argument('--early_stopping_start', type=int, default=71)

flags = parser.parse_args()
print('Flags:')

# Final accuracies collected once per restart and summarised after each one.
final_train_accs = []  # NOTE(review): never appended to in the visible code
final_valid_accs = []
final_test_accs = []

def mean_nll(logits, y):
    """Mean cross-entropy of unnormalized ``logits`` against class-index targets ``y``."""
    nll = nn.functional.cross_entropy(input=logits, target=y)
    return nll

def mean_accuracy(logits, y):
    """Return, as a 0-d float tensor, the fraction of rows whose argmax over dim 1 equals ``y``."""
    predicted = logits.argmax(dim=1).float()
    return predicted.eq(y).float().mean()

def pretty_print(*values):
    """Print ``values`` on one line as left-justified, 13-character-wide cells.

    String values are printed verbatim; any other value is rendered with
    ``np.array2string`` using 5 fixed decimal places. Cells are separated by
    three spaces.
    """
    width = 13
    cells = []
    for value in values:
        if isinstance(value, str):
            text = value
        else:
            text = np.array2string(value, precision=5, floatmode='fixed')
        cells.append(text.ljust(width))
    print("   ".join(cells))


# Echo every parsed flag, sorted by name, one tab-indented line each.
_flag_items = sorted(vars(flags).items())
for flag_name, flag_value in _flag_items:
    print("\t{}: {}".format(flag_name, flag_value))

# Step size of the multiplicative-weights update applied to the group weights.
_GROUP_WEIGHT_STEP = 1e-2


def make_environment(images, labels, color_prob, patch_prob, label_prob, n_classes):
    """Build one colored-MNIST environment.

    Args:
        images: digit images unpacked as (n, h, w); presumably uint8 in 0-255
            (torchvision MNIST ``.data``), given the /255 scaling below.
        labels: digit labels 0-9.
        color_prob: probability that the color disagrees with the (noisy) label.
        patch_prob: probability that the corner-patch cue disagrees with the
            (noisy) label, or ``None`` to add no patches. Patches are only
            supported for ``n_classes == 2``.
        label_prob: probability of corrupting each collapsed label.
        n_classes: number of classes the digits are binned into (2, 3, 5 or 10).

    Returns:
        dict moved to ``device`` with 'images' (float in [0, 1], shape
        (n, n_classes, h, w)), 'labels' (noisy labels), 'colors', optionally
        'patches', and 'real' (noise-free collapsed labels).
    """
    def torch_bernoulli(p, size):
        # 1.0 with probability p, else 0.0, for `size` independent draws.
        return (torch.rand(size) < p).float()

    def collapse_labels(digits, n_classes):
        # Bin digits into equal-width groups; leftover digits fold into the last class.
        assert n_classes in [2, 3, 5, 10]
        bin_width = 10 // n_classes
        # Floor division keeps labels integral (true division gave fractional
        # labels that only happened to truncate to the right class later).
        return torch.div(digits, bin_width, rounding_mode='floor').clamp(max=n_classes - 1)

    def corrupt(values, n_classes, prob):
        # With probability `prob`, shift each label to the next class (mod n_classes).
        is_corrupt = torch_bernoulli(prob, len(values)).bool()
        return torch.where(is_corrupt, (values + 1) % n_classes, values)

    labels = collapse_labels(labels, n_classes).float()
    real_labels = labels.long().to(device)
    # *Corrupt* the label with probability label_prob.
    labels = corrupt(labels, n_classes, label_prob)
    # Color follows the noisy label, flipped with probability color_prob.
    colors = corrupt(labels, n_classes, color_prob)

    # Place each image in the channel selected by its color; other channels stay zero.
    n, h, w = images.size()
    colored_images = torch.zeros((n, n_classes, h, w)).to(images)
    colored_images[torch.arange(n), colors.long()] = images

    squares = None
    if patch_prob is not None:
        assert n_classes == 2, 'Only support binary classification'
        squares = corrupt(labels, n_classes, patch_prob).long()
        # Cue for class 0: white 3x3 top-left patch; class 1: white 2x2
        # bottom-right patch; written on every channel.
        colored_images[squares == 0, :, :3, :3] = 255
        colored_images[squares == 1, :, -2:, -2:] = 255

    env = {
        'images': (colored_images.float() / 255.).to(device),
        'labels': labels.long().to(device),
        'colors': colors.long().to(device),
    }
    if squares is not None:
        env['patches'] = squares.to(device)
    env['real'] = real_labels
    return env


class MLP(nn.Module):
    """ReLU MLP with two hidden layers of width ``flags.hidden_dim``.

    Inputs are flattened to ``n_classes * 28 * 28`` features. ``functional_forward``
    runs the same architecture with an explicit name->tensor mapping (keys as
    produced by ``named_parameters()``), so inner-loop fast weights stay
    differentiable with respect to the module's own parameters.
    """

    def __init__(self, n_classes):
        super(MLP, self).__init__()
        self.n_classes = n_classes
        self.in_features = n_classes * 28 * 28
        lin1 = nn.Linear(self.in_features, flags.hidden_dim)
        lin2 = nn.Linear(flags.hidden_dim, flags.hidden_dim)
        lin3 = nn.Linear(flags.hidden_dim, n_classes)
        for lin in [lin1, lin2, lin3]:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)
        self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True))
        self.classifier = lin3

    def forward(self, input):
        out = input.view(input.shape[0], self.in_features)
        return self.classifier(self._main(out))

    def functional_forward(self, input, weights):
        out = input.view(input.shape[0], self.in_features)
        out = F.relu(F.linear(out, weights['_main.0.weight'], weights['_main.0.bias']))
        out = F.relu(F.linear(out, weights['_main.2.weight'], weights['_main.2.bias']))
        return F.linear(out, weights['classifier.weight'], weights['classifier.bias'])


def _adapt(model, env):
    """Return fast weights after ``flags.inner_steps`` differentiable SGD steps on ``env``.

    Gradients use ``create_graph=True`` so the outer loss can backpropagate
    through the adaptation into ``model``'s parameters. The last inner step's
    NLL and accuracy are left in ``env['nll']`` / ``env['acc']``.
    """
    fast_weights = OrderedDict(
        (name, param) for name, param in model.named_parameters() if param.requires_grad
    )
    for _ in range(flags.inner_steps):
        logits = model.functional_forward(env['images'], fast_weights)
        env['nll'] = mean_nll(logits, env['labels'])
        env['acc'] = mean_accuracy(logits, env['labels'])
        grads = autograd.grad(env['nll'], tuple(fast_weights.values()), create_graph=True)
        # DropGrad: scale each gradient entry by N(1, dropgrad^2) noise
        # (std 0 leaves the gradient unchanged).
        fast_weights = OrderedDict(
            (name, param - flags.inner_lr * (grad * torch.normal(
                mean=torch.ones_like(grad), std=torch.ones_like(grad) * flags.dropgrad)))
            for (name, param), grad in zip(fast_weights.items(), grads)
        )
    return fast_weights


def _group_loss(logits, labels, mask):
    """Mean NLL over the rows selected by ``mask``; zero when no row is selected.

    A mean over an empty group is NaN, which would permanently poison the
    multiplicative group weights.
    """
    if not bool(mask.any()):
        return logits.new_zeros(())
    return mean_nll(logits[mask], labels[mask])


def _evaluate(model, env):
    """Score ``model`` on ``env`` in eval mode without gradients; store 'nll' and 'acc' in env."""
    model.eval()
    with torch.no_grad():
        logits = model(env['images'])
        env['nll'] = mean_nll(logits, env['labels'])
        env['acc'] = mean_accuracy(logits, env['labels'])


n_classes = flags.class_no
split = flags.split

# Corner-patch cue flip probabilities for train/valid/test (None disables the cue).
if flags.add_patch:
    patch_prob1, patch_prob2, patch_val, patch_test = 0.1, 0.2, 0.2, 0.9
else:
    patch_prob1 = patch_prob2 = patch_val = patch_test = None

for restart in range(flags.n_restarts):
    print("Restart", restart)
    mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
    mnist_train = (mnist.data[:50000], mnist.targets[:50000])
    mnist_val = (mnist.data[50000:], mnist.targets[50000:])

    mnist_for_test = datasets.MNIST('~/datasets/mnist', train=False, download=True)
    mnist_test = (mnist_for_test.data, mnist_for_test.targets)

    # Shuffle images and labels in place with the same permutation by
    # replaying the numpy RNG state (.numpy() shares memory with the tensors).
    rng_state = np.random.get_state()
    np.random.shuffle(mnist_train[0].numpy())
    np.random.set_state(rng_state)
    np.random.shuffle(mnist_train[1].numpy())

    train_envs = [
        make_environment(mnist_train[0][::split], mnist_train[1][::split], 0.1, patch_prob1, 0.25, n_classes),
        make_environment(mnist_train[0][1::split], mnist_train[1][1::split], 0.2, patch_prob2, 0.25, n_classes),
    ]
    valid_env = make_environment(mnist_val[0][::2], mnist_val[1][::2], 0.2, patch_val, 0.25, n_classes)

    mlp = MLP(n_classes).to(device)
    outer_optimizer = optim.Adam(mlp.parameters(), lr=flags.outer_lr)
    valid_buffer = float('inf')

    pretty_print('step', 'train nll', 'validation nll', 'validation acc')

    # One weight per (environment, correctly/wrongly classified) group.
    n_groups = 2 * len(train_envs)
    group_weight = torch.full((n_groups,), 1.0 / n_groups, device=device)

    for step in range(flags.steps):
        mlp.train()
        outer_loss_list = []
        # Adapt on one environment, then evaluate the adapted weights on the other.
        for env_no, (env, env_) in enumerate(zip(train_envs, train_envs[::-1])):
            fast_weights = _adapt(mlp, env)

            logits = mlp.functional_forward(env_['images'], fast_weights)
            target = env_['labels']
            correct_idx = torch.argmax(logits, dim=1) == target
            correct_loss = _group_loss(logits, target, correct_idx)
            wrong_loss = _group_loss(logits, target, ~correct_idx)

            # Multiplicative-weights update: groups with higher loss gain weight.
            group_weight[2 * env_no] *= torch.exp(_GROUP_WEIGHT_STEP * correct_loss.detach())
            group_weight[2 * env_no + 1] *= torch.exp(_GROUP_WEIGHT_STEP * wrong_loss.detach())

            weight_norm = sum(w.norm().pow(2) for w in fast_weights.values())
            penalty = flags.l2_regularizer_weight * weight_norm
            outer_loss_list.append(correct_loss + penalty)
            outer_loss_list.append(wrong_loss + penalty)

        group_weight = group_weight / group_weight.sum()
        total_outer_loss = torch.stack(outer_loss_list) @ group_weight

        outer_optimizer.zero_grad()
        total_outer_loss.backward()
        outer_optimizer.step()

        _evaluate(mlp, valid_env)

        if step % 10 == 0:
            # Values follow the header order. NOTE: the 'train nll' column is
            # the group-weighted outer loss, including the L2 term.
            pretty_print(
                np.int32(step),
                total_outer_loss.detach().cpu().numpy(),
                valid_env['nll'].detach().cpu().numpy(),
                valid_env['acc'].detach().cpu().numpy(),
            )

        # Stop once validation NLL rises by more than 0.01 over the previous
        # step, but only after `early_stopping_start` steps.
        if valid_env['nll'] - valid_buffer > 0.01 and step > flags.early_stopping_start:
            break
        valid_buffer = valid_env['nll']

    # Test: patches only when the model was trained with them (the patch code
    # also only supports binary tasks).
    test_env = make_environment(mnist_test[0], mnist_test[1], 0.9, patch_test, 0.25, n_classes)
    _evaluate(mlp, test_env)

    print('Final step:', step)
    print("Validation accuracy is {:.4f}".format(valid_env['acc'].item()))
    print("Test accuracy is {:.4f}".format(test_env['acc'].item()))

    final_valid_accs.append(valid_env['acc'].detach().cpu().numpy())
    final_test_accs.append(test_env['acc'].detach().cpu().numpy())

    print('Final valid acc (mean/std across restarts so far):')
    print(np.mean(final_valid_accs), np.std(final_valid_accs))
    print('Final test acc (mean/std across restarts so far):')
    print(np.mean(final_test_accs), np.std(final_test_accs))